#!/usr/bin/env python3
#
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""
HRW4U Knowledge Graph Builder

Extract knowledge graph data from hrw4u configuration files and test data.
Generates structured JSON output suitable for graph database import.
"""

import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Iterator

from antlr4 import InputStream, CommonTokenStream
from antlr4.error.ErrorListener import ErrorListener

from hrw4u.hrw4uLexer import hrw4uLexer
from hrw4u.hrw4uParser import hrw4uParser
from hrw4u.kg_visitor import KnowledgeGraphVisitor, KGData
from hrw4u.errors import ErrorCollector

# Configure the root logger when this module is loaded: INFO level, messages
# prefixed with a time-only (HH:MM:SS) timestamp and the level name.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%H:%M:%S')
# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)


class SyntaxErrorListener(ErrorListener):
    """Error listener that accumulates syntax errors as formatted strings.

    Every reported error is appended to ``errors`` in the form
    ``"Line <line>:<column> - <message>"`` so callers can inspect or log the
    full list after parsing.
    """

    def __init__(self):
        super().__init__()
        # Formatted error messages, in the order they were reported.
        self.errors: list[str] = []

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        """Record a single syntax error; only the position and message are kept."""
        formatted = "Line {}:{} - {}".format(line, column, msg)
        self.errors.append(formatted)


def discover_hrw4u_files(paths: list[Path]) -> Iterator[Path]:
    """Yield every hrw4u input file reachable from ``paths``, each at most once.

    * A path that is a regular file is yielded when its suffix is ``.hrw4u``
      or its name ends with ``.input.txt``.
    * A directory is searched recursively for ``*.hrw4u`` and ``*.input.txt``
      files. Matches are visited in sorted order so repeated runs process
      files in the same sequence.
    * Any file whose name contains ``"fail"`` (case-insensitive) is skipped;
      those are test cases that are expected to fail parsing.
    * A path that is neither a file nor a directory is reported with a
      warning and otherwise ignored.

    Args:
        paths: Files and/or directories to inspect, in priority order.

    Yields:
        Paths of the discovered files. A file reachable through more than one
        argument (for example a directory and a file inside it) is yielded
        only the first time it is encountered.
    """
    # Resolved paths already yielded; used to suppress duplicates.
    seen: set[Path] = set()

    def first_visit(candidate: Path) -> bool:
        key = candidate.resolve()
        if key in seen:
            return False
        seen.add(key)
        return True

    for path in paths:
        if path.is_file():
            # Skip files with "fail" in the name (expected to fail parsing)
            if "fail" in path.name.lower():
                logger.debug(f"Skipping failure test case: {path}")
                continue
            # Path.suffix only holds the last component (".txt"), so the
            # two-part ".input.txt" extension must be matched on the name.
            is_hrw4u = path.suffix == '.hrw4u' or path.name.endswith('.input.txt')
            if is_hrw4u and first_visit(path):
                yield path
        elif path.is_dir():
            # Recursively find hrw4u files, excluding failure cases
            for pattern in ('**/*.hrw4u', '**/*.input.txt'):
                for file_path in sorted(path.glob(pattern)):
                    # A directory can also match the pattern; only files are inputs.
                    if not file_path.is_file():
                        continue
                    if "fail" in file_path.name.lower():
                        logger.debug(f"Skipping failure test case: {file_path}")
                        continue
                    if first_visit(file_path):
                        yield file_path
        else:
            logger.warning(f"Path not found: {path}")


def parse_hrw4u_file(filepath: Path) -> tuple[hrw4uParser.ProgramContext | None, list[str]]:
    """Parse a single hrw4u file and return its parse tree plus any errors.

    The file is read as UTF-8 and run through the hrw4u lexer and parser.
    One ``SyntaxErrorListener`` is installed on both the lexer and the
    parser, replacing their existing listeners, so that tokenisation errors
    and grammar errors both end up in the returned list.

    Args:
        filepath: Path of the hrw4u source file to parse.

    Returns:
        A ``(tree, errors)`` tuple. ``tree`` is the result of the ``program``
        rule, or ``None`` when reading or parsing raised an exception.
        ``errors`` holds the collected syntax error messages, or a single
        ``"Failed to parse ..."`` message when an exception was raised; it is
        empty for a clean parse.
    """
    try:
        content = filepath.read_text(encoding='utf-8')

        error_listener = SyntaxErrorListener()

        # The lexer reports errors through its own listener list, so it needs
        # the collecting listener as well as the parser does.
        lexer = hrw4uLexer(InputStream(content))
        lexer.removeErrorListeners()
        lexer.addErrorListener(error_listener)

        parser = hrw4uParser(CommonTokenStream(lexer))
        parser.removeErrorListeners()
        parser.addErrorListener(error_listener)

        # Parse program
        ast = parser.program()

        return ast, error_listener.errors

    except Exception as e:
        # Boundary handler: any read/decode/parse failure is reported to the
        # caller as an error string rather than aborting the whole run.
        return None, [f"Failed to parse {filepath}: {e}"]


def extract_kg_data(filepath: Path, debug: bool = False) -> KGData | None:
    """Parse ``filepath`` and run the knowledge-graph visitor over the result.

    Args:
        filepath: hrw4u source file to process.
        debug: Passed through to ``KnowledgeGraphVisitor``.

    Returns:
        The value produced by ``visitProgram``, or ``None`` when the file has
        parse errors, yields no parse tree, or the visitor raises. Each
        failure is logged at error level; issues recorded by the visitor's
        error collector are logged as warnings and do not discard the result.
    """
    logger.debug(f"Processing {filepath}")

    tree, parse_errors = parse_hrw4u_file(filepath)

    # Any parse error disqualifies the file.
    if parse_errors:
        logger.error(f"Parse errors in {filepath}:")
        for message in parse_errors:
            logger.error(f"  {message}")
        return None

    if not tree:
        logger.error(f"Failed to parse {filepath}")
        return None

    try:
        collector = ErrorCollector()
        kg_visitor = KnowledgeGraphVisitor(
            filename=str(filepath),
            debug=debug,
            error_collector=collector,
        )
        graph = kg_visitor.visitProgram(tree)

        # Collector entries are reported but are not fatal.
        if collector.has_errors():
            logger.warning(f"Visitor warnings for {filepath}:")
            for issue in collector.get_errors():
                logger.warning(f"  {issue}")

        logger.debug(f"Extracted {len(graph.nodes)} nodes, {len(graph.edges)} edges from {filepath}")
        return graph

    except Exception as e:
        logger.error(f"Failed to extract KG data from {filepath}: {e}")
        return None


def merge_kg_data(kg_data_list: list[KGData]) -> KGData:
    """Fold a list of KG data objects into one by chaining ``merge`` calls.

    Args:
        kg_data_list: Objects to combine, in merge order.

    Returns:
        A ``KGData`` with empty node and edge lists when the input is empty;
        the sole element itself when there is exactly one; otherwise the
        result of merging each subsequent element into the running result,
        left to right.
    """
    if not kg_data_list:
        return KGData(nodes=[], edges=[])

    remaining = iter(kg_data_list)
    merged = next(remaining)  # the first element seeds the fold
    for addition in remaining:
        merged = merged.merge(addition)
    return merged


def export_json(kg_data: KGData, output_path: Path, pretty: bool = True) -> None:
    """Write the knowledge graph, plus a metadata summary, to a JSON file.

    The payload is ``kg_data.to_dict()`` with an added ``"metadata"`` entry
    holding a format version, node and edge counts, and the distinct node and
    relationship types. The type lists are sorted so that the same graph
    always produces the same file. Missing parent directories of
    ``output_path`` are created.

    Args:
        kg_data: Graph to export; its ``nodes``, ``edges`` and ``to_dict()``
            are used.
        output_path: Destination file, written as UTF-8.
        pretty: Indent the JSON by two spaces when true; otherwise write it
            on a single line.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    data = kg_data.to_dict()

    # A bare set has no stable iteration order between runs, so sort the
    # distinct types before they go into the file.
    node_types = sorted({node.type for node in kg_data.nodes})
    relationship_types = sorted({edge.relationship for edge in kg_data.edges})

    # Add metadata
    data['metadata'] = {
        'version': '1.0',
        'node_count': len(kg_data.nodes),
        'edge_count': len(kg_data.edges),
        'node_types': node_types,
        'relationship_types': relationship_types,
    }

    with output_path.open('w', encoding='utf-8') as f:
        json.dump(data, f, indent=2 if pretty else None, ensure_ascii=False)

    logger.info(f"Exported KG data to {output_path}")
    logger.info(f"  Nodes: {len(kg_data.nodes)}")
    logger.info(f"  Edges: {len(kg_data.edges)}")
    logger.info(f"  Node types: {len(node_types)}")
    logger.info(f"  Relationship types: {len(relationship_types)}")


def _compute_kg_stats(kg_data: KGData) -> dict:
    """Build the statistics dictionary reported by ``export_stats``."""
    stats = {
        'total_nodes': len(kg_data.nodes),
        'total_edges': len(kg_data.edges),
        'node_types': {},
        'relationship_types': {},
        'top_connected_nodes': []
    }

    # Histogram of node types, plus an id -> node index. setdefault keeps the
    # first node seen for an id if the same id occurs more than once.
    node_types = stats['node_types']
    nodes_by_id = {}
    for node in kg_data.nodes:
        node_types[node.type] = node_types.get(node.type, 0) + 1
        nodes_by_id.setdefault(node.id, node)

    # Histogram of relationship types and per-node degree; each edge counts
    # once for its source and once for its target.
    relationship_types = stats['relationship_types']
    node_connections = {}
    for edge in kg_data.edges:
        relationship_types[edge.relationship] = relationship_types.get(edge.relationship, 0) + 1
        node_connections[edge.source_id] = node_connections.get(edge.source_id, 0) + 1
        node_connections[edge.target_id] = node_connections.get(edge.target_id, 0) + 1

    # Ten highest-degree ids; the sort is stable, so ties keep first-seen order.
    top_nodes = sorted(node_connections.items(), key=lambda x: x[1], reverse=True)[:10]
    for node_id, count in top_nodes:
        node = nodes_by_id.get(node_id)
        if node is None:
            # The edge refers to an id that has no node entry; nothing to report.
            continue
        stats['top_connected_nodes'].append(
            {
                'id': node_id,
                'type': node.type,
                'connections': count,
                'properties': node.properties
            })

    return stats


def export_stats(kg_data: KGData, output_path: Path | None = None) -> None:
    """Report knowledge graph statistics to a JSON file or to the console.

    The statistics are: total node and edge counts, a count per node type and
    per relationship type, and up to ten of the most connected nodes (by
    number of edge endpoints) with their type, connection count and
    properties.

    Args:
        kg_data: Graph to summarise; its ``nodes`` and ``edges`` are read.
        output_path: When given, the statistics are written there as indented
            UTF-8 JSON and missing parent directories are created. When
            omitted, a human-readable summary is printed to standard output,
            listing at most five of the most connected nodes.
    """
    stats = _compute_kg_stats(kg_data)

    if output_path:
        # Create the target directory first, as export_json does for its output.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open('w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)
        logger.info(f"Exported statistics to {output_path}")
        return

    print("\n=== Knowledge Graph Statistics ===")
    print(f"Total nodes: {stats['total_nodes']}")
    print(f"Total edges: {stats['total_edges']}")
    print("\nNode types:")
    for node_type, count in sorted(stats['node_types'].items()):
        print(f"  {node_type}: {count}")
    print("\nRelationship types:")
    for rel_type, count in sorted(stats['relationship_types'].items()):
        print(f"  {rel_type}: {count}")
    if stats['top_connected_nodes']:
        print("\nMost connected nodes:")
        for node_info in stats['top_connected_nodes'][:5]:
            # Prefer the node's "name" property as the label; fall back to its id.
            label = node_info.get('properties', {}).get('name', node_info['id'])
            print(f"  {node_info['type']} ({node_info['connections']} connections): {label}")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser used by ``main``."""
    usage_examples = """
Examples:
  # Process test data directory
  hrw4u_kg tests/data/ -o kg_data.json

  # Process specific files
  hrw4u_kg file1.hrw4u file2.hrw4u -o combined.json

  # Process with debug output and statistics
  hrw4u_kg tests/data/ -o kg.json --stats stats.json --debug

  # Just show statistics without saving
  hrw4u_kg tests/data/ --show-stats
"""
    cli = argparse.ArgumentParser(
        description='Extract knowledge graph data from hrw4u files',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument('paths', nargs='+', type=Path, help='Input paths (files or directories) to process')
    cli.add_argument('-o', '--output', type=Path, help='Output JSON file for KG data')
    cli.add_argument('--stats', type=Path, help='Output JSON file for statistics')
    cli.add_argument('--show-stats', action='store_true', help='Display statistics to console')
    cli.add_argument('--debug', action='store_true', help='Enable debug logging')
    cli.add_argument('--compact', action='store_true', help='Use compact JSON output (no pretty printing)')
    return cli


def main() -> int:
    """Run the command-line tool: discover, extract, merge and export.

    Returns:
        ``0`` when at least one file produced KG data; ``1`` when no input
        files were found or none of them could be processed.
    """
    options = _build_arg_parser().parse_args()

    if options.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    input_files = list(discover_hrw4u_files(options.paths))
    if not input_files:
        logger.error("No hrw4u files found in specified paths")
        return 1

    logger.info(f"Found {len(input_files)} hrw4u files to process")

    # One entry per file that was extracted successfully.
    graphs = []
    for source in input_files:
        graph = extract_kg_data(source, debug=options.debug)
        if not graph:
            logger.warning(f"Skipped {source} due to errors")
            continue
        graphs.append(graph)

    if not graphs:
        logger.error("No valid KG data extracted from any files")
        return 1

    logger.info(f"Successfully processed {len(graphs)}/{len(input_files)} files")
    merged = merge_kg_data(graphs)

    if options.output:
        export_json(merged, options.output, pretty=not options.compact)

    if options.stats:
        export_stats(merged, options.stats)

    if options.show_stats:
        export_stats(merged)

    # With no destination requested, fall back to printing the statistics.
    if not (options.output or options.stats or options.show_stats):
        logger.warning("No output specified. Use -o, --stats, or --show-stats to see results.")
        export_stats(merged)  # Show stats by default

    return 0


if __name__ == '__main__':
    sys.exit(main())
